# Compare GMST smoothing/estimation methods (running mean, OCN, Butterworth,
# change point, EKF, RTS, LENS2 ensemble mean, blind model) against the
# HadCRUT5 measurements and save the compare_hists1, compare_hists2 and
# compare_means figures.
import ekf_testv6 as ekf  # project module: provides temps, filter/smoother outputs, dates and plot colours used below
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes  # NOTE(review): not used in the code shown here
import scipy.stats as ss  # NOTE(review): not used in the code shown here
# Set before importing summary_of_sims; if that import draws anything, this
# ordering matters -- confirm before moving it below the imports.
plt.rcParams["font.family"] = "Helvetica"
import summary_of_sims as sum_sims  # project module: LENS2 ensemble mean/std (twTRmean, twTRstd) and cesmcolor
from statsmodels.graphics import tsaplots
import matplotlib.transforms as mtransforms

# Number of rows in the comparison matrices: eight estimation methods plus the
# raw HadCRUT5 measurements in the last row (index 8).
matsh=9
# Row order: 0 running 30-year mean, 1 OCN, 2 Butterworth, 3 change point,
# 4 EKF, 5 RTS, 6 coupled-simulation (LENS2) average, 7 blind model,
# 8 HadCRUT5 (shlabels/colorlabels only). stdevs and moderrs only have rows
# 0..matsh-3, i.e. no uncertainty rows for the blind model or the measurements.

# Backslashes in the mathtext labels are either in raw literals or doubled
# (where a real "\n" line break is needed), so none of them is parsed as a
# Python string escape.
labels=[r"Moving Average $\overline{_{30}Y_n}$", "OCN Chunks", "Butterworth Smoothed", "Change Point Lines",
        r"Ext. Kalman Filter $\hat {T }_n$", r"RTS $ \hat \hat {T }_n$",
        r"LENS2 $\overline{(Y_n)_j}$", r"Blind Model $\~{T}_{n}$"]
shlabels=["RunAvg \n $\\overline{_{30}Y_n}$", "OCN \n Chunks", "ButW \n Smooth", "$\\Delta$Pt \n Lines",
          "EKF \n $\\hat {T }_n$", "RTS \n $ \\hat \\hat {T }_n$",
          "LENS2 \n $\\overline{(Y_n)_j}$", "Blind \n $\\~{T}_{n}$", "HadCRUT5 \n ${Y}_{n}$"]
colorlabels=["mediumseagreen","darkgreen","springgreen", ekf.pcolor, ekf.colorekf, ekf.colorrts ,sum_sims.cesmcolor,"darkgoldenrod",ekf.colorgrey]

# Unit weights, later divided by the non-NaN count so histogram bars show
# the fraction of timepoints.
oneswt=np.ones(ekf.n_iters)
# Result matrices start as all-NaN so methods that cover only part of the
# record leave the remaining columns empty.
values = np.full((matsh, ekf.n_iters), np.nan)
stdevs = np.full((matsh-2, ekf.n_iters), np.nan)
moderrs = np.full((matsh-2, ekf.n_iters), np.nan)
temps=ekf.temps

# Centred running N-year mean of the measurements (row 0).
# Each window holds exactly N samples: cN samples before index i and fN
# samples from i onward (cN + fN == N). For even N that is i-N/2 .. i+N/2-1.
# Previously the slice end was i+fN+1, which averaged N+1 samples and
# silently truncated the last window to N.

# Sanity check that temps has one entry per column of the result matrices.
print(len(ekf.temps)==ekf.n_iters)

N = 30
n_temps = len(ekf.temps)
moving_aves = np.full(n_temps, np.nan)
std_aves = np.full(n_temps, np.nan)
cN = int(np.ceil(N / 2))
fN = int(np.floor(N / 2))

# Only indices whose full window fits inside the record get a value; the
# first cN and last fN-1 entries stay NaN.
for i in range(cN, n_temps - fN + 1):
    window = ekf.temps[i - cN:i + fN]
    moving_aves[i] = np.mean(window)
    std_aves[i] = np.std(window)

values[0,:]=moving_aves
stdevs[0,:]=std_aves
# Standard error of the window mean, treating the N samples as independent.
moderrs[0,:]=std_aves/np.sqrt(N)

# OCN results (row 1). CSV row 1 -> values, row 2 -> moderrs, row 3 (halved) -> stdevs.
# NOTE(review): the /2 presumably converts a full interval width to a
# one-sided spread -- confirm against the exporter.
with open("Mann_Smoothing_08/OCNtemperatures.csv", "rb") as fh:
    dataOCN = np.genfromtxt(fh, dtype=float, delimiter=',')
# OCN may cover fewer columns than the full record; the rest stay NaN.
lenOCN = dataOCN.shape[1]
values[1,:lenOCN]=dataOCN[1,:]
stdevs[1,:lenOCN]=dataOCN[3,:]/2
moderrs[1,:lenOCN]=dataOCN[2,:]

# Butterworth-smoothed results (row 2).
with open("Mann_Smoothing_08/BUStemperatures.csv", "rb") as fh:
    dataBU = np.genfromtxt(fh, dtype=float, delimiter=',')
# The CSV has one more column than the record; the trailing one is dropped.
values[2,:]=dataBU[1,:-1]
# Uncertainties are averaged over an N-sample window around each index:
# cN samples before i and fN samples from i onward (cN + fN == N).
# NOTE(review): the factor 2 on both rows is kept from the original -- confirm its meaning.
n_rec = len(ekf.temps)
for i in range(cN, n_rec - fN + 1):
    window = slice(i - cN, i + fN)
    stdevs[2, i] = 2 * np.mean(np.abs(dataBU[3, window]))
    moderrs[2, i] = 2 * np.mean(dataBU[2, window])

# Bayesian sequential change-point results (row 3). As with the Butterworth
# file, the trailing CSV column is dropped. CSV row 1 -> values,
# row 2 -> moderrs, row 3 (halved) -> stdevs.
with open("Bayes_Sequential_Change_Point/BSCtemperatures.csv", "rb") as fh:
    dataCP = np.genfromtxt(fh, dtype=float, delimiter=',')
values[3,:]=dataCP[1,:-1]
stdevs[3,:]=dataCP[3,:-1]/2
moderrs[3,:]=dataCP[2,:-1]

# Extended Kalman filter (row 4): state estimate together with the
# innovation (S) and state (P) standard deviations from ekf_testv6.
values[4, :], stdevs[4, :], moderrs[4, :] = ekf.xh1s, ekf.stdS, ekf.stdP

# RTS (presumably the Rauch-Tung-Striebel smoother) estimates (row 5); the
# spread at the very last step is blanked out.
values[5, :], stdevs[5, :], moderrs[5, :] = ekf.xhh1s, ekf.stdSh, ekf.stdPh
stdevs[5, -1] = np.nan

# Coupled simulations (row 6): LENS2 ensemble mean/std for the first 173
# columns, and row 20 of the CMIP6 summary (first 165 columns) as model error.
# NOTE(review): 173 was previously derived from sum_sims.smyrs -- confirm it still matches.
lenss=173
values[6,:lenss]=sum_sims.twTRmean[:lenss]
stdevs[6,:lenss]=sum_sims.twTRstd[:lenss]
lenss=165
with open("summaryCMIP6.csv", "rb") as fh:
    CMIP6 = np.genfromtxt(fh, dtype=float, delimiter=',')
moderrs[6,:lenss]=CMIP6[20,:]

# Blind model (row 7, first component of ekf.xblind) and the raw HadCRUT5
# measurements (row 8).
values[7,:]=ekf.xblind[:,0]
values[8,:]=temps

# Figure "compare_means": axl (left) shows temperature means over time,
# axp (right) plots innovation against the following change of the estimate.
figh, (axl, axp) = plt.subplots(1, 2, figsize=(12, 4))
figh.subplots_adjust(wspace=0.4)

# EKF: measurement minus state at step k, paired with the state change k -> k+1.
ekf_state = values[4, :]
diffptsEKF = temps[0:-1] - ekf_state[:-1]
derivptsEKF = ekf_state[1:] - ekf_state[:-1]

# Running mean: same pairing, trimmed to where the mean is defined.
run_mean = values[0, :]
diffpts30m = temps[15:-16] - run_mean[15:-16]
derivpts30m = run_mean[16:-15] - run_mean[15:-16]

scatter_series = (
    (diffptsEKF, derivptsEKF, colorlabels[4], "EBM-KF"),
    (diffpts30m, derivpts30m, colorlabels[0], "30-year running mean"),
)
for xs, ys, colour, series_name in scatter_series:
    axp.plot(xs, ys, '.', color=colour, label=series_name, zorder=2)

axp.set_xlabel("Innovation from Mean/State to Measurement")
axp.set_ylabel("Change in Mean or State")
axp.set_title("Comparison of 30-year mean and EBM-KF Distributions")

# Legend for the two scatter series. It is kept as a separate artist because
# the volcano legend below replaces the axes' default legend.
legend1 = axp.legend(loc="lower right")

# Eruption episodes highlighted with grey markers in the scatter. The first
# year of each episode is also drawn as a dotted line on the time-series panel.
noteYears = [[1991,1992,1993,1994,1995,1996,1997,1998],[1963,1964,1965],[1902,1903,1904],[1883,1884,1885],
             [1856,1857,1858]]
noteVolcs = ["Mt. Pinatubo, Philippines", "Mt. Agung, Indonesia","Santa Maria, Guatemala", "Krakatoa, Indonesia",
             "? Mt. Awu, Indonesia ?"] # "Komaga-take, Japan" "Shiveluch, Russia"
symbols = ['^', '>', 's','P','*','X']
# Horizontal nudges (in points) so overlapping year labels stay legible.
year_label_xshift = {1902: -5, 1963: -10, 1998: 3}

volplots=[]
for years, marker in zip(noteYears, symbols):
    # Year Y maps to index Y-1851 of the EKF diff/deriv arrays, i.e. the state
    # change that ends in year Y (assumes ekf.dates starts at 1850 -- confirm).
    ds = np.array(years) - 1850 - 1
    volplot, = axp.plot(diffptsEKF[ds], derivptsEKF[ds], marker=marker, linestyle='None', color='0.7', zorder=0)
    volplots.append(volplot)
    for d in ds:
        year = int(d) + 1850 + 1
        # Put the label above rising points and below falling ones.
        axp.annotate(str(year), (diffptsEKF[d], derivptsEKF[d]), fontsize=8, font='Arial Narrow',
                     textcoords="offset points",
                     xytext=(year_label_xshift.get(year, 0), -3 + 6 * np.sign(derivptsEKF[d])),
                     ha='center')

# The first step is annotated once, outside the per-volcano loop.
axp.annotate("1850 - initialized", (diffptsEKF[0], derivptsEKF[0]), fontsize=8, font='Arial Narrow',
             textcoords="offset points", xytext=(0, -3 - 6), ha='center')
legend2 = axp.legend(volplots, noteVolcs, loc="center right", fontsize=8)
axp.add_artist(legend1)
axp.set_ylim(-0.25, 0.11)
# Print the 20 most negative EKF state changes with their dates.
sortedderivs = sorted(zip(ekf.dates[1:], derivptsEKF), key=lambda pair: pair[1])
print(sortedderivs[0:20])

# Agreement (r2) between pairs of estimates. Comparisons with the running
# mean are restricted to indices where that mean is defined.
r2_pairs = (
    ('r2 score for running mean is to EBM-KF', values[0, 15:-16], values[4, 15:-16]),
    ('r2 score for running mean is to blind', values[0, 15:-16], values[7, 15:-16]),
    ('r2 score for LENS2 ens to EBM-KF', values[6, :], values[4, :]),
    ('r2 score for blind to EBM-KF', values[7, :], values[4, :]),
)
for message, reference, estimate in r2_pairs:
    r2 = ekf.r2_score(reference, estimate)
    print(message, r2)


# Panel tags "a)", "b)" placed 40 pt left of and 1 pt above each axes'
# top-left corner (offsets are in inches via figh.dpi_scale_trans).
axlabels = ['a)', 'b)', 'c)']

axs = (axl, axp)
for ax, label in zip(axs, axlabels):
    offset = mtransforms.ScaledTranslation(-40 / 72, 1 / 72, figh.dpi_scale_trans)
    ax.text(0.0, 1.0, label, transform=ax.transAxes + offset,
            fontsize='large', va='bottom')



# Histogram grid: matsh rows (one per method, HadCRUT5 last) by 4 columns.
# The bottom row comes from its own GridSpec (top=0.8, default hspace); the
# other rows use a tightly stacked GridSpec (hspace=0.1).
gs_bot = plt.GridSpec(matsh, 4, wspace=0.3, top=0.8)
gs_base = plt.GridSpec(matsh, 4, wspace=0.3, hspace=0.1)
fig = plt.figure(figsize=(12, 10))

ax0s = []
for j in range(4):
    botax = fig.add_subplot(gs_bot[matsh - 1, j])
    column_axes = []
    for row in range(matsh - 1):
        column_axes.append(fig.add_subplot(gs_base[row, j]))
    column_axes.append(botax)
    ax0s.append(column_axes)

# Transpose the per-column lists into axs[row][column].
axs = [[column[row] for column in ax0s] for row in range(matsh)]

fig.suptitle("Histogram Comparisons of Smoothing Methods on GMST\n (Bar Height Represents Fraction of Timepoints)")
def _hide_axes_frame(axis):
    """Blank out an unused grid cell: no spines, no ticks, no y tick labels, drawn behind neighbours."""
    for side in ('right', 'top', 'bottom', 'left'):
        axis.spines[side].set_visible(False)
    axis.tick_params(direction='out', bottom=False, top=False, left=False, right=False)
    axis.set_zorder(-1)
    axis.set_yticklabels([])


# Shared bin edges and column-3 axis settings.
diff_bins = np.linspace(-0.5, 0.5, 50)           # column 0: measurement minus estimate
change_bins = np.linspace(-0.15, 0.15, 150)      # column 1: yearly change of smoothed rows
raw_change_bins = np.linspace(-0.35, 0.35, 50)   # column 1: yearly change of raw HadCRUT5 (wider spread)
ranflickerb = 32   # most negative run length shown in column 3
ranflicker = 8     # most positive run length shown in column 3
stepflicker = 8    # tick spacing in column 3

for d in range(matsh):
    # ---- Column 0: distribution of (HadCRUT5 - estimate) ----
    ax = axs[d][0]
    if d == 0:
        ax.set_title("Real HadCRUT5\n Measurements\n Minus Value")
    if d == 7:
        ax.set_xlabel("Temperature Difference (K)")
    else:
        ax.set_xticklabels([])
    if d <= 7:
        diff = -values[d, :] + temps
        count = np.count_nonzero(~np.isnan(diff))
        # Dividing unit weights by the non-NaN count makes bar heights fractions of timepoints.
        ax.hist(diff, weights=oneswt / count, bins=diff_bins, alpha=1, color=colorlabels[d])
        ax.set_ylim(0.0, 0.14)
        ax.set_xlim(-0.4, 0.4)
        ax.tick_params(direction='out', bottom=True, top=False, left=True, right=True)
    else:
        # Measurements minus measurements is identically zero: leave the cell blank.
        _hide_axes_frame(ax)
    ax.set_ylabel(shlabels[d])

    # ---- Column 1: distribution of yearly changes ----
    ax = axs[d][1]
    if d == 0:
        ax.set_title("Yearly Value Change")
    if d == 8:
        ax.set_xlabel("Temperature Change (K/year)")
    elif d != 7:
        ax.set_xticklabels([])
    diff = values[d, 1:] - values[d, :-1]
    count = np.count_nonzero(~np.isnan(diff))
    # The raw measurements get one coarse, wide histogram instead of being
    # drawn twice; fine bins limited to +-0.15 would drop their larger changes.
    change_edges = raw_change_bins if d == 8 else change_bins
    ax.hist(diff, weights=oneswt[1:] / count, bins=change_edges, alpha=1, color=colorlabels[d])
    if d == 1:
        ax.set_ylim(0.0, 0.11)  # OCN chunks: almost all changes are 0, so clip that bar
        ax.text(0.002, 0.08, "*")
    if d == 8:
        ax.set_xlim(-0.25, 0.25)
        ax.set_xticks(np.arange(-0.2, 0.25, 0.1))
    else:
        ax.set_xlim(-0.065, 0.065)
        ax.set_xticks(np.arange(-0.06, 0.07, 0.03))
    ax.tick_params(direction='out', bottom=True, top=False, left=True, right=True)

    # ---- Column 2: autocorrelation of the departure from the blind model (row 7) ----
    ax2 = axs[d][2]
    if d != 7:
        # missing='drop' removes the NaN edges before computing the ACF.
        tsaplots.plot_acf(values[d, :] - values[7, :], lags=30, ax=ax2, missing='drop', title='')
        ax2.set_ylim(-0.9, 1.1)
    else:
        _hide_axes_frame(ax2)
    if d == 0:
        ax2.set_title("Autocorrelation of\n Difference from Blind Model")
    if d in (6, 8):
        ax2.set_xlabel("Lag Time (years)")
    else:
        ax2.set_xticklabels([])

    # ---- Column 3: lengths of runs above/below the measurements ----
    ax3 = axs[d][3]
    ax3.set_ylim(0, 35)
    ax3.set_xticks(np.arange(-ranflickerb, ranflicker + stepflicker, step=stepflicker))
    if d == 0:
        ax3.set_title("Crossing Times\n with Real HadCRUT5")
    if d == 7:
        ax3.set_xlabel("String of Years\n Above(+) or Below(-)\n all years given equal weight")
    else:
        ax3.set_xticklabels([])
    if d < 8:
        # Encode each year as '1' (estimate >= measurement) or '0'. Splitting
        # on the opposite symbol yields run lengths: runs below are negative,
        # runs above positive. Empty pieces give zero-length runs with zero weight.
        nonnans = ~np.isnan(values[d, :])
        gted = "".join(np.greater_equal(values[d, nonnans], values[8, nonnans]).astype(int).astype(str))
        negs = -np.fromiter(map(len, gted.split('1')), dtype=int)
        posis = np.fromiter(map(len, gted.split('0')), dtype=int)
        histo = np.concatenate((negs, posis))
        # Weighting each run by its length makes bar height count years, not runs.
        # 81 bins over a 40.5-wide range are 0.5 wide, one per integer run length.
        ax3.hist(histo, weights=np.abs(histo), alpha=1, bins=(ranflicker + ranflickerb) * 2 + 1,
                 range=(-ranflickerb - 0.25, ranflicker + 0.25), color=colorlabels[d])
    else:
        _hide_axes_frame(ax3)


axl.set_title('Comparison of Temperature Means')

# Running mean (row 0), EKF (row 4) and LENS2 mean (row 6), plotted relative
# to ekf.pindavg.
toplot = [0, 4, 6]
for row in toplot:
    anomaly = values[row, :] - ekf.pindavg
    axl.plot(ekf.dates, anomaly, '-', label=labels[row], color=colorlabels[row])
axl.legend()
ekf.plot_boilerplate(axl)

# Dotted grey line at the first year of each highlighted volcanic episode.
for episode in noteYears:
    start_year = episode[0]
    axl.plot([start_year, start_year], [-0.5, 1.7], ':', color='0.7', zorder=0)

axl.set_yticks(np.arange(-0.4, 1.8, 0.2))
axl.set_ylabel('$\\Delta$ Temperature (°C)\nAbs(K)=$\\Delta(T)$ + 286.7K')
axl.set_ylim(286.3 - ekf.pindavg, 288.1 - ekf.pindavg)

# Twin axis: the same 286.3-288.1 K window expressed in degrees Celsius.
axb = axl.twinx()
axb.set_yticks(np.arange(12.4, 18.4, 0.2))
axb.set_ylim(286.3 - 273.15, 288.1 - 273.15)
axb.set_ylabel('Temperature (°C)')

# Per-row (0..6) indices of the first (triangle) and last (square) marker on
# each trajectory in the uncertainty plot below, and the row drawing order.
startpts = [15, 0, 15, 0, 2, 2, 0]
endpts = [-15, -2, -15, -1, -1, -2, -13]
order = [2, 6, 0, 1, 3, 5, 4]

# Saves the current pyplot figure (the histogram grid, assuming nothing above
# switched the current figure).
hist_stem = "compare_hists1"
plt.savefig(hist_stem + ".pdf", format="pdf")
plt.savefig(hist_stem + ".png", dpi=400, format="png")

# Uncertainty trajectories: each row's spread (x) against its model/state
# error (y), drawn in `order`. Triangles mark startpts, squares mark endpts.
plt.figure()
for pos, d in enumerate(order):
    label2 = shlabels[d].replace('\n', '')
    col = colorlabels[d]
    # The last two rows in `order` (RTS, EKF) are drawn heavier so they stand out.
    if pos <= 4:
        plt.plot(stdevs[d,:], moderrs[d,:], '-o', color=col, label=label2, alpha=0.7, markersize=2)
    else:
        plt.plot(stdevs[d,:], moderrs[d,:], '-o', color=col, label=label2, alpha=1, markersize=3, linewidth=2)
    for idx, marker in ((startpts[d], '^'), (endpts[d], 's')):
        plt.plot(stdevs[d, idx], moderrs[d, idx], marker, color=col, markersize=5,
                 markeredgewidth=.5, markeredgecolor=(0, 0, 0, 1))

# Only the trajectory lines carry labels, so handles follow `order`; order2
# maps them back to row order 0..6. A separate name avoids clobbering the
# module-level `labels` list.
order2 = [2, 3, 0, 4, 6, 5, 1]
handles, legend_labels = plt.gca().get_legend_handles_labels()
plt.legend([handles[idx] for idx in order2], [legend_labels[idx] for idx in order2])
plt.title("Comparison of Modeled GMST Variabilities")
plt.ylabel(r"State Uncertainty $\sqrt{P_{n}}$ or Std Error in K")
plt.xlabel(r"Prediction / Innovation Uncertainty $\sqrt{S_{n}}$ or Std Dev in K")
plt.ylim([0, 0.15])
plt.xlim([0, 0.23])
plt.xticks(np.arange(0, .24, 0.02))

plt.savefig("compare_hists2.pdf", format="pdf")
plt.savefig("compare_hists2.png", dpi=400, format="png")

figh.savefig("compare_means.pdf", format="pdf")
figh.savefig("compare_means.png", dpi=400, format="png")


plt.show()





